{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn.functional as F\n",
    "import models\n",
    "import matplotlib.pyplot as plt\n",
    "from helper import *\n",
    "torch.cuda.set_device(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "classes = ['sky', 'building', 'pole', 'road', 'pavement', 'tree', 'signsymbol', 'fence', 'car', 'pedestrian', 'bicyclist', 'unlabelled']\n",
    "DATA_DIR = '../data/CamVid/'\n",
    "x_train_dir = os.path.join(DATA_DIR, 'trainsmall')\n",
    "y_train_dir = os.path.join(DATA_DIR, 'trainsmallannot')\n",
    "\n",
    "x_valid_dir = os.path.join(DATA_DIR, 'val')\n",
    "y_valid_dir = os.path.join(DATA_DIR, 'valannot')\n",
    "\n",
    "x_test_dir = os.path.join(DATA_DIR, 'test')\n",
    "y_test_dir = os.path.join(DATA_DIR, 'testannot')\n",
    "\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean = [0.41189489566336, 0.4251328133025, 0.4326707089857], std = [0.27413549931506, 0.28506257482912, 0.28284674400252])\n",
    "])\n",
    "\n",
    "train_dataset = CamVid(\n",
    "    x_train_dir,\n",
    "    y_train_dir,\n",
    "    classes = classes,\n",
    "    transform = transform\n",
    ")\n",
    "\n",
    "valid_dataset = CamVid(\n",
    "    x_valid_dir,\n",
    "    y_valid_dir,\n",
    "    classes = classes, \n",
    "    transform = transform\n",
    ")\n",
    "\n",
    "test_dataset = CamVid(\n",
    "    x_test_dir,\n",
    "    y_test_dir,\n",
    "    classes = classes, \n",
    "    transform = transform\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainloader = DataLoader(train_dataset, batch_size = 8, shuffle = True, drop_last = True)\n",
    "valloader = DataLoader(valid_dataset, batch_size = 1, shuffle = False)\n",
    "testloader = DataLoader(test_dataset, batch_size = 1, shuffle = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unet = models.unet.Unet('resnet18', classes = 12, encoder_weights = None).cuda()\n",
    "unet.load_state_dict(torch.load('../saved_models/resnet18/pretrained.pt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_iou(model, dataloader) : \n",
    "    gpu1 = 'cuda:0'\n",
    "    ious = list()\n",
    "    for i, (data, target) in enumerate(dataloader) : \n",
    "        data, target = data.float().to(gpu1), target.long().to(gpu1)\n",
    "        prediction = model(data)\n",
    "        prediction = F.softmax(prediction, dim = 1)\n",
    "        prediction = torch.argmax(prediction, axis = 1).squeeze(1)\n",
    "        ious.append(iou(target, prediction, num_classes = 11))\n",
    "        \n",
    "    return (sum(ious) / len(ious))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5357400573392819\n",
      "0.36408891822227735\n"
     ]
    }
   ],
   "source": [
    "print(mean_iou(unet, valloader))\n",
    "print(mean_iou(unet, testloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unet = models.unet.Unet('resnet34', classes = 12, encoder_weights = None).cuda()\n",
    "unet.load_state_dict(torch.load('../saved_models/resnet34/pretrained.pt'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5547654143631603\n",
      "0.37584598449861767\n"
     ]
    }
   ],
   "source": [
    "print(mean_iou(unet, valloader))\n",
    "print(mean_iou(unet, testloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pixelwise_acc(mask1, mask2) :\n",
    "    equals = (mask1 == mask2).sum().item()\n",
    "    return equals / (mask1.shape[0] * mask1.shape[1] * mask1.shape[2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f1d6c506e10>"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAK0AAAD8CAYAAAAFfSQRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAYVklEQVR4nO2de5AV1Z3HPz/nIQ8jI/IaGHVCYLNONilBxAfGskSqEhIxtaUbLde4WxrKV0pjKhE22f0nqUXz0DxXgrIpNcZH0IroGi0kWq6mgqAxZGGCDoiG1wCJgCJmHDn7R5++9Nzpe2/fe7tvn+77+1Tdut2nT3f/pu93fv0753SfnxhjUJQscVTaBihKtaholcyholUyh4pWyRwqWiVzqGiVzJGIaEXkUyKySUT6RGRREudQmheJu59WRFqAV4F5wDZgLXCpMWZjrCdSmpYkPO1soM8Ys8UYMwA8AFyYwHmUJqU1gWNOAf4cWN8GnF5cSUQWAgsBWqTt1Jbu8TB4FLQertuAETs+qPsYtWL+NpDaueNCjm5P9fwH/ta/1xgzvtT2JEQrIWXDYhBjzDJgGcCYEZPMuP+8NjYD/u4b+2I7VrUMbtma2rnjorWrO9XzP7n5u2+U255EeLANOCGw3gXsKLfDe5NbYjXg1W91xHo8xS2SEO1aYLqIfFhE2oFLgJUJnEdpUmIXrTFmELgeeAroBR4yxmyI+zzlSDM8aJ3andq54yAL9icR02KMeQJ4Iolju04eYlrX0RExpUAWvCw4Ito4u6jSDA2yTFYEC46IFjyxqeDSIUuCBYdE61NKuIf7RxQ+1e6rlCZrgoWEGmJxUUqgh/tH8Pc/2jWkzJW+2dap3ZlpjGVRsOC4aMMoFquPetnqGdyyNZPCdVK0WRdglrxtFnEuplUaQxY9rI+zonUlRs0zWRWus6LNeojgOlkOX5wVbdZx3Yu5bl85VLRK5lDRJoir3sxVu6LiZJdXXnAlbsy6SItRT5sgeROLK6holcyh4UGOyaunV0+rZA71tDmg2KNm9UGYqKinzSF5FiyoaGvi1W91OPNsRN4FGoaKtkqCYnVFuM2GirYKwkRaSbitU7sT84bN6GVBRatkEBVtRmlWLwva5RULYSGCPg+cHOppq0CF6AbqaRuEfzuv58mvZg4JgqinTYi4usP83gcV7BHU09aIHyok1VerIi2NetoqKZ5zrFycm/bgQ9rnTwoVbY6ppuGYJYFXFK2InCAiz4hIr4hsEJEbbPlYEVklIq/Z7+NsuYjID23iu/UiMjPpPyJtSomj2t6GWuLXesTmP0ORJcFCNE87CHzFGHMycAZwnYj0AIuA1caY6cBquw7waWC6/SwE7ojdageJKtDB805N5PzVCq+WIWlXqChaY8xOY8zLdvltvDwKU/AS2t1tq90NfM4uXwjcYzx+B3SISGfsljtIFOGO2PqX2M/riy0roquXqmJaEekGZgBrgInGmJ3gCRuYYKuFJb+bUq+hjearq2pLyBMUblQRRQkJip8uC97W/XM2y+BHZNGKyDHAw8CNxpgD5aqGlA1LficiC0VknYisG/jg3ahmNARfsHEItxL1dm3FKdisiD6SaEWkDU+w9xljHrHF/f5t337vtuWRkt8ZY5YZY2YZY2a1t4yq1X5nKTcdvwt9sMW2ZUWwEGFwQUQEWA70GmNuC2xaCVwB3GK/Hw2UXy8iD+DlxN3vhxFZ5KurVvKdeQvqPkaQc0ce5hu7P86DG0/l8z0v8ft//EjFY4QNZtQrtCwJNUgUTzsHuBw4T0ResZ/5eGKdJyKvAfPsOnj5w7YAfcCdQHxJbzNIWIjxjd0fB+DzPS9Vfbw8JVT505cm1bRfRU9rjHme8DgVYG5IfQNcV5M1DlBrHBt2nKge+r3u40N7FWY8srngjX18r3zn6ffAKuq+C2QRHRFLkOA/wLkjD4fWWfW9s3lnSjt753Syd04nMx7ZXPjAcG/sl0c5Z602hx0jrn/mIKXyZ1RCH5gp4jvzFsT2AxUfyxfu029766u+d/aQ+vO+8nyk41YSbrVEFWkc8X0cqGgbyOk3X1N2+6rvnR1ZuD7FQqpWVEl40KQRLwRNlzEjJpmzui5P24wh+D9mLZ7ljYsnh5aP3lHdtS4n4F/d/0l+vHBp7D0bUUja2z65+bsvGWNmldquMW0Jav1hSgk2Tvyw4vplV/PGxZOHfaqh2r9TwwPHifoDRRFKtV4Whse8UY8XZs9Jvxw2vlMW/293MXxQ0ZahVMOjGm9Wi1jrJXjOg5O93spSNoeJuThGdk24KtqINOK2nwRhAg7yxsWTuX7Z1fx44VLAjdt/JVS0FahHrGl42XKE2eML+fplV3sFFw/f79yRh3n2kNf8caHbSxtiCeGaYEsxeoeJZKvfx1wQd4qop8Xzpn5sF/Sstf5AWRFskGKbg6HEP/zAe3zk/274L6D83afaBl8tNF0/bdKxaRYFG4WweLgaqhFzpX7apvC0jWpE5VWwULlBV4lqei8qkWvRNrLFn2fBFlOuQVctob/RLcOLguRStI3unmomwZaiXEwcN5kWrQt9pyrYcOoNJ8qRSdG6IFYlOnELODOidVGo6mWrJ44wwnnRuihWUMHGRS1e2FnRuipWUMEmRdTr6uQwrgpWKYdzolXBKpVwTrSuooJ1B2diWvWwSlSc8LQDx7WlbUJJVLDu4YRoFaUaVLRlUC/rJiraEqhg3UVFG4IK1m1UtEWoYN1HRatkDmf6adNGPWx2UE+LCjZrVJPdpkVEfi8ij9v1D4vIGpux8UERabflR9v1Pru9OxnT40EFmz2q8bQ34CW+87kVuN1mbHwLuNKWXwm8ZYyZBtxu6zmJCjabRE3J1AV8BrjLrgtwHrDCVinO2OhnclwBzLX1nUIFm12ietrvA18D/MQBxwP7jDGDdj2YlbGQsdFu32/rDyGY/G7w3YM1mq80I1GykH8W2G2MCWasKJeVMVLGxmDyu9ZRoyMZGwdR565S3CVKl9ccYIHNHTYCOBbP83aISKv1psGsjH7Gxm0i0gqMAf4au+U1oGLNB1GykC82xnQZY7qBS4DfGGMuA54BLrLVijM2XmGXL7L1U1eLCjY/1NNPezNwk4j04cWsy235cuB4W34TsKg+ExVlKFWNiBljngWetctbgNkhdd4jdGre9FAvmy9yPyKmgs0fuRatCjaf5Fa0Ktj8kkvRqmDzTe5Eq4LNP7kSrQq2OciNaFWwzUMuRKuCbS4yL1oVbPORadGqYJsTJ5LficjbwKa07ajAOGBv2kZEIAt2VrLxJGPM+FIbXXkbd1O5DH0uICLrXLcRsmFnvTZmOjxQmhMVrZI5XBHtsrQNiEAWbIRs2FmXjU40xBSlGlzxtIoSGRWtkjlSF62IfEpENtlplFJ7n0xEThCRZ0SkV0Q2iMgNtnysiKyy0z+tEpHjbLmIyA+t3etFZGYDbXV6iioR6RCRFSLyJ3s9z4z1OhpjUvsALcBmYCrQDvwB6EnJlk5gpl3+EPAq0AN8G1hkyxcBt9rl+cCv8eZ5OANY00BbbwJ+ATxu1x8CLrHLS4Fr7PK1wFK7fAnwYIPsuxu4yi63Ax1xXse0RXsm8FRgfTGwOE2bArY8CszDG6nrtGWdeAMhAD8FLg3UL9RL2K4uYDXetFSP2x97L9BafE2Bp4Az7XKrrScJ23cs8HrxeeK8jmmHB4UplCzB6ZVSw95GZwBrgInGmJ0A9nuCrZaW7bFPURUzU4E9wM9sCHOXiIwmxuuYtmgjTaHUSETkGOBh4EZjzIFyVUPKErU9qSmqYqYVmAncYYyZARyk/NwXVduYtmj9KZR8gtMrNRwRacMT7H3GmEdscb+IdNrtncBuW56G7f4UVVuBB/BChMIUVSF2FGxs4BRV24Btxpg1dn0Fnohju45pi3YtMN22ftvxGgsr0zDETke6HOg1xtwW2BSc5ql4+qcv2NbvGcB+//aXFCYDU1QZY3YBfxaRj9qiucBG4ryODjR45uO11DcDX0/RjrPxbkvrgVfsZz5eDLgaeM1+j7X1BfiJtfuPwKwG23suR3oPpgIvAn3AL4GjbfkIu95nt09tkG2nAOvstfwVcFyc11GHcZXMkUh44MqAgZJPYve0ItKCd7ufhxdkr8Xrh9sY64mUpiUJTzsb6DPGbDHGDOC1ci9M4DxKk5LE6zZhncWnF1cSkYXAQoAWaTt1dNtYAP42rg3TVrv3bzsgtBx8v+b968UMDKR27riQ9vbQctPeggx8kPj5Dwz07zUNfkcscs4F7MPAY46eZM6achkAfVdNYWD8YHH1yPQs6fdGulNicOub6Z08JlqnnBhaPtA1lvZtyWciePL1294otz2J8MCpAQMlPhoh2CgkIVpnBgzSoLU73EvlgYGusWmbACQQHhhjBkXkerwnjFqA/zbGbIi6fz2hgVI/WfinS2TeA2PME8ATSRzbdfIQ07pO2s8eKA5RycvmOaZVMkgWwgIfV6ZFKtCzpH9Y2b7Zkzk46SgOTTCMW3+k92zvJ4Rx6w075ibfd5hnsiRYcFC0YXS8uCO067XjxaHfSvVkTbCg4UHsZEkEWbI1iIpWyRwq2gTIqgfLCipaJXOoaJXMoaJVMoeKNiE0rk0OFa2SOVS0CfKh/x2XtgmhtHafmOk7QSZGxLLK2590IzNSlgUahnraBMmbWFxBRZtz8viPo6JNCBdeTcmjYEFFq2QQFW2OCHrWuL2sC3cOH+09yAFJitXHlVdtQD1t1Wxf0JW2CUNoVNyqnjaj7Js9mf0fe5/9H5tYKBuzoY0pK7eV3e/w2adw1POvJGLT4NY3c9vgKoWKNmFau08kqdm9GilWDQ+UzOFSeKCijQH/Bx3oGhv647rkpfKAhgdVEPaquhfjtgETA6UTmXZvstONNlscG0Q9bRWM2dCWtgkK6mmrYsrKbUwpmv9x4+KJoXWLQwLfM9Y611dr94mFfZvZy4KKNhHa98R7WX2RNrtYfTQ8SICB8YP0XTU8vWs1XtZ/UDvoYRUPFW0DqdVTqocdioq2TnqW9Bc+QUbuDks9Af+x5eWKx4xDpH1XTSkZb2edisGXiJwA3ANMwkvXvswY8wMRGQs8CHQDW4F/Msa8ZXPM/gAvRee7wL8YYyr/Ujnj0ITaMvTE5VUrzai+fUEX+z92pFsubLZKV4niaQeBrxhjTgbOAK4TkR68dOirjTHT8XKd+pkZPw1Mt5+FwB2xW+04Yza0Me2u7aHb5owIv+S1vGy4cfHE3HrTclT0tMbLCL3TLr8tIr14ucIuxEsqDHA38Cxwsy2/x3ipIH8nIh0i0mkSztDtGr6YkvRg5Y6dJc9ZLVXFtCLSDcwA1gATfSHa7wm2Wljyu2FNaRFZKCLrRGTdwAfvVm95wpzzWG/VXqxnST/T7n2/5FNfYcdr7T6Rp3ZUfgLMtUci0yRyh6KIHAM8DNxojDngha7hVUPKKia/i2pHIzlmwsGa9w2OnvkJ/XqW9MMFQ+v9z2+jZavyH4nMsweNSiRPKyJteIK9zxjziC3uF5FOu70T2G3LM5/8zvdq/zxtbdX7+iNh3uiZ53Gn3bU9dAg4iodVhhOl90CA5UCvMea2wKaVwBXALfb70UD59SLyAF5O3P1Zi2cv/eKqwvI5j/Xy3AUn133MYMhQi1h7lvSzb/bkuu0I2jP+5cakDY2bKOHBHOBy4I8i4l/tf8MT60MiciXwJnCx3fYEXndXH16X17/GanFGOeexXs4evYl6usY7Xkz2hrV9QVfFtzBcIErvwfOEx6kAc0PqG+C6Ou1KjZH3xtMo9D30yHvf5bQOLz+x39110ebzAVjxkae5aPP5rPjI04XY1t92Wscb3Hz8a7zw3mHmjDiKz5y1YNg5BrrG0nd5W2xx7qVfXMVzK09m42K3Y2d9YCZhfMECnHLLtQBMu/hVAObceLX3zTQAFvz76kLdtftO4qJ9JwFwO0f+mQ5dPgrwBHv+HS9wPrBy9lwOXHqAE7/8Duc81svT18yp+rbvH3/7gi6uPmsVz1F/SJQUOoybAjt/NI2dP5o2rHzlN4fduErSvu2vNtyAA5ceGLatGvbNnlz45wrG80HOv+OF2O5C9aKiTRC/F+L+O+fx877TABi963DJ+ns/UbIbsUDfVVOG9Nk+d8HJHHv/sUPqVNtgG7XjvUj1gneNNNHwoIj9i7tYu6S+H8gTVS/gCRag5ZmOsoIFGLfesHP9EQ/c+aW+YXUGxg8ycncb2xd08fzBjxYEfOKX3ynUidpgi+qRz3msN1K9RqGetojiH7Ka7q7tC7qGeNcglQQblWMmHGT0rsPDjuefe+lvz4385mxYvaW/PbfsPi4IWD1tnSQ5vBoW9wYDgZXfnMtoiv8Z2tgzsw1mjhq27/iX3w31rsGyniX98NjQ7Tcf/xq3/mV6YTntRpqKtgYGusayJ0QUpYjLy9ZyroOTjtxM98wcVRDz+JejNarevP2YYWUDXekOSqhoQzh0+ahh3qZWj9pIwUY5vy/iPTNHMXrXYfbM9P+u8rf95y442YnQAFS0Fdk3e/IQb1UNaQs2jKBNwb/r/jvncWiCYeQC4f47u8COZbyz+31eeG/o35H20K82xEIIdhnlSbDFFDfowl4RGrOhbUgvxSm3XMu+2ZNjfQ6iWprW0wZbzu9OHhEqzuIegKhkQbBBSnnfYn7edxoDEwy+rzsYEjIFG3tJxb7iPSqQLmOOnmTOmnJZQ85Vz+0+ClkTbFTiuGZRH8Z58vXbXjLGzCq1vSk8baOe+s+rYKF8r0RUSv0Oo3cdruoJttyKNmmPWkyeBRtGHCIO7uuHGqN3HYbXy9fPnWir7UONg2YTbBhR4+JKRNk3F6JN86U/Fexw4vTCYWROtPpWavbwRRyXeDMhWleFql62OsKuVy1Cdlq0rooVVLBxUUss7KRoG93yrxYVbDJEva7OidZl7woqWBdwyp2pYJUoOCNaFawSFSdE+36H21ljVLBu4YRoXUYF6x4q2jKoYN1ERVsCFay7qGiVzKGiDUG9rNuoaJXMoaItQr2s+6hoA6hgs0Fk0YpIi4j8XkQet+sfFpE1IvKaiDwoIu22/Gi73me3dydjeryoYLNDNZ72BoZOQ3IrcLtNfvcWcKUtvxJ4yxgzDW8+4FvjMDRJVLDZImp2my7gM8Bddl2A84AVtsrdwOfs8oV2Hbt9rpTJ35Q2KtjsEdXTfh/4GhSm6Dse2GeM8ROwBhPcFZLf2e37bf0hBJPfDR6qPV9XPahgs0lF0YrIZ4HdxpiXgsUhVU2EbUcKjFlmjJlljJnVOnJ0JGPjRAWbXaKmZFogIvOBEXhTpH4f6BCRVutNgwnu/OR320SkFRgDOJWsSgWbbSp6WmPMYmNMlzGmG7gE+I0x5jLgGeAiW604+d0VdvkiWz/9uZcsKtjsU08/7c3ATSLShxezLrfly4HjbflNwKL6TIwPFWw+qOodMWPMs8CzdnkLMDukznscyd7oDCrY/NAUI2Iq2HyRe9GqYPNHrkWrgs0nuRWtCja/ODdZR72oWPNPbj2tkl9yJVr1ss1BbkSrgm0eciFaFWxzkQvRKs1FpnsP1MM2J04kvxORt4FNadtRgXHA3rSNiEAW7Kxk40nGmPGlNrriaTeVy9DnAiKyznUbIRt21mujxrRK5lDRKpnDFdEuS9uACGTBRsiGnXXZ6ERDTFGqwRVPqyiRUdEqmSN10YrIp0Rkk537K7WXIEXkBBF5RkR6RWSDiNxgy8eKyCo7Z9kqETnOlouI/NDavV5EZjbQVqfnVRORDhFZISJ/stfzzFivozEmtQ/QAmwGpgLtwB+AnpRs6QRm2uUPAa8CPcC3gUW2fBFwq12eD/wab3KSM4A1DbT1JuAXwON2/SHgEru8FLjGLl8LLLXLlwAPNsi+u4Gr7HI70BHndUxbtGcCTwXWFwOL07QpYMujwDy8kbpOW9aJNxAC8FPg0kD9Qr2E7eoCVuPNpfa4/bH3Aq3F1xR4CjjTLrfaepKwfccCrxefJ87rmHZ4UJj3yxKcEyw17G10BrAGmGiM2QlgvyfYamnZHvu8ajEzFdgD/MyGMHeJyGhivI5pizbSvF+NRESOAR4GbjTGHChXNaQsUduTmlctZlqBmcAdxpgZwEHKT9hStY1pi9af98snOCdYwxGRNjzB3meMecQW94tIp93eCey25WnY7s+rthV4AC9EKMyrFmJHwcYGzqu2DdhmjFlj11fgiTi265i2aNcC023rtx2vsbAyDUPsHLrLgV5jzG2BTcG5yYrnLPuCbf2eAez3b39JYTIwr5oxZhfwZxH5qC2aC2wkzuvoQINnPl5LfTPw9RTtOBvvtrQeeMV+5uPFgKuB1+z3WFtfgJ9Yu/8IzGqwvedypPdgKvAi0Af8Ejjalo+w6312+9QG2XYKsM5ey18Bx8V5HXUYV8kcaYcHilI1Klolc6holcyholUyh4pWyRwqWiVzqGiVzPH/J8PQnKrEVVYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "img, mask = next(iter(trainloader))\n",
    "img, mask = img.cuda(0), mask.cuda(0)\n",
    "out = unet(img)\n",
    "out = F.softmax(out, dim = 1)\n",
    "y = out.cpu().detach().numpy()\n",
    "y = np.argmax(y, axis = 1)\n",
    "mask_ = mask.cpu().detach().numpy()\n",
    "\n",
    "fig, axes = plt.subplots(2, 1)\n",
    "axes[0].imshow(y[0])\n",
    "axes[1].imshow(mask_[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([8, 480, 640]) torch.Size([8, 480, 640])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.9293461100260416"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out2 = torch.argmax(out, dim = 1).squeeze(1)\n",
    "print(out2.shape, mask.shape)\n",
    "pixelwise_acc(out2, mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dice_coeff(mask1, mask2, smooth = 1e-6, num_classes = 19) : \n",
    "    dice = 0\n",
    "    for sem_class in range(num_classes) : \n",
    "        pred_inds = (mask1 == sem_class)\n",
    "        target_inds = (mask2 == sem_class)\n",
    "        intersection = (pred_inds[target_inds]).long().sum().item()\n",
    "        denom = pred_inds.long().sum().item() + target_inds.long().sum().item()\n",
    "        dice += (float(2 * intersection) + smooth) / (float(denom) + smooth)\n",
    "    return dice / num_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avg_dice_coeff(model, dataloader) : \n",
    "    gpu1 = 'cuda:0'\n",
    "    gpu2 = 'cuda:1'\n",
    "    ious = list()\n",
    "    for i, (data, target) in enumerate(dataloader) : \n",
    "        data, target = data.float().to(gpu1), target.long().to(gpu1)\n",
    "        prediction = model(data)\n",
    "        prediction = F.softmax(prediction, dim = 1)\n",
    "        prediction = torch.argmax(prediction, axis = 1).squeeze(1)\n",
    "        ious.append(dice_coeff(target, prediction, num_classes = 11))\n",
    "        \n",
    "    return (sum(ious) / len(ious))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.6281462205043835"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "avg_dice_coeff(unet, valloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.63176  +-  0.01848 ---- 0.52617  +-  0.00872 ---- 0.35491  +-  0.00422\n",
      "0.62573  +-  0.01292 ---- 0.5477  +-  0.00672 ---- 0.36425  +-  0.00472\n",
      "0.64372  +-  0.0064 ---- 0.54241  +-  0.00485 ---- 0.36209  +-  0.00276\n",
      "0.63487  +-  0.00806 ---- 0.53814  +-  0.01072 ---- 0.37006  +-  0.0038\n",
      "0.65151  +-  0.01434 ---- 0.55949  +-  0.01342 ---- 0.37647  +-  0.00518\n",
      "0.67577  +-  0.00784 ---- 0.60715  +-  0.00842 ---- 0.42734  +-  0.006\n"
     ]
    }
   ],
   "source": [
    "for model_name in ['resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26', 'resnet34'] : \n",
    "    train_ious = list()\n",
    "    val_ious = list()\n",
    "    test_ious = list()\n",
    "    unet = models.unet.Unet(model_name, classes = 12, encoder_weights = None).cuda()\n",
    "    for i in range(3) : \n",
    "        unet.load_state_dict(torch.load('../saved_models/' + model_name + '/pretrained_' + str(i) + '.pt'))\n",
    "        current_val_iou = mean_iou(unet, valloader)\n",
    "        current_train_iou = mean_iou(unet, trainloader)\n",
    "        current_test_iou = mean_iou(unet, testloader)\n",
    "        train_ious.append(current_train_iou)\n",
    "        val_ious.append(current_val_iou)\n",
    "        test_ious.append(current_test_iou)\n",
    "    train_mean = np.mean(train_ious)\n",
    "    train_std_dev = np.std(train_ious)\n",
    "    val_mean = np.mean(val_ious)\n",
    "    val_std_dev = np.std(val_ious)\n",
    "    test_mean = np.mean(test_ious)\n",
    "    test_std_dev = np.std(test_ious)\n",
    "    print(round(train_mean, 5), ' +- ', round(train_std_dev, 5), '----', round(val_mean, 5), ' +- ', round(val_std_dev, 5), '----', round(test_mean, 5), ' +- ', round(test_std_dev, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6888440573380167\n",
      "0.6170405935881389\n",
      "0.4356820135808164\n",
      "0.6608705057992386\n",
      "0.5964584916274905\n",
      "0.42183620589369875\n",
      "0.6773658402202525\n",
      "0.6079386404367901\n",
      "0.4244905592057309\n"
     ]
    }
   ],
   "source": [
    "for i in range(3) : \n",
    "    unet = models.unet.Unet('resnet34', classes = 12, encoder_weights = None).cuda()\n",
    "    unet.load_state_dict(torch.load('../saved_models/resnet34/pretrained_' + str(i) + '.pt'))\n",
    "    current_val_iou = mean_iou(unet, valloader)\n",
    "    current_train_iou = mean_iou(unet, trainloader)\n",
    "    current_test_iou = mean_iou(unet, testloader)\n",
    "    print(current_train_iou)\n",
    "    print(current_val_iou)\n",
    "    print(current_test_iou)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6546765093381575\n",
      "0.5977865289519019\n",
      "0.4214258224376416\n"
     ]
    }
   ],
   "source": [
    "unet = models.unet.Unet('resnet10', classes = 12, encoder_weights = None).cuda()\n",
    "unet.load_state_dict(torch.load('../saved_models/resnet10/classifier/model0.pt'))\n",
    "print(mean_iou(unet, trainloader))\n",
    "print(mean_iou(unet, valloader))\n",
    "print(mean_iou(unet, testloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.6535  +-  0.00065 ---- 0.59959  +-  0.00188 ---- 0.42046  +-  0.00141\n",
      "0.65567  +-  0.00167 ---- 0.61364  +-  0.00061 ---- 0.42784  +-  0.00201\n",
      "0.66328  +-  0.00157 ---- 0.60884  +-  0.00288 ---- 0.42486  +-  0.00078\n",
      "0.6647  +-  0.00109 ---- 0.6082  +-  0.00163 ---- 0.42703  +-  0.00022\n",
      "0.66226  +-  0.0004 ---- 0.60037  +-  0.00139 ---- 0.42656  +-  0.00113\n"
     ]
    }
   ],
   "source": [
    "for model_name in ['resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'] : \n",
    "    train_ious = list()\n",
    "    val_ious = list()\n",
    "    test_ious = list()\n",
    "    unet = models.unet.Unet(model_name, classes = 12, encoder_weights = None).cuda()\n",
    "    for i in range(3) : \n",
    "        unet.load_state_dict(torch.load('../saved_models/' + model_name + '/classifier/model' + str(i) + '.pt'))\n",
    "        current_val_iou = mean_iou(unet, valloader)\n",
    "        current_train_iou = mean_iou(unet, trainloader)\n",
    "        current_test_iou = mean_iou(unet, testloader)\n",
    "        train_ious.append(current_train_iou)\n",
    "        val_ious.append(current_val_iou)\n",
    "        test_ious.append(current_test_iou)\n",
    "    train_mean = np.mean(train_ious)\n",
    "    train_std_dev = np.std(train_ious)\n",
    "    val_mean = np.mean(val_ious)\n",
    "    val_std_dev = np.std(val_ious)\n",
    "    test_mean = np.mean(test_ious)\n",
    "    test_std_dev = np.std(test_ious)\n",
    "    print(round(train_mean, 5), ' +- ', round(train_std_dev, 5), '----', round(val_mean, 5), ' +- ', round(val_std_dev, 5), '----', round(test_mean, 5), ' +- ', round(test_std_dev, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.62542  +-  0.00096 ---- 0.55015  +-  0.00484 ---- 0.38187  +-  0.00239\n",
      "0.61828  +-  0.00688 ---- 0.56153  +-  0.00226 ---- 0.3925  +-  0.00252\n",
      "0.63054  +-  0.00495 ---- 0.56978  +-  0.006 ---- 0.39678  +-  0.00043\n",
      "0.62635  +-  0.00063 ---- 0.56548  +-  0.0072 ---- 0.39863  +-  0.00043\n",
      "0.62772  +-  0.00471 ---- 0.56626  +-  0.00247 ---- 0.39943  +-  0.00244\n"
     ]
    }
   ],
   "source": [
    "for model_name in ['resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'] : \n",
    "    train_ious = list()\n",
    "    val_ious = list()\n",
    "    test_ious = list()\n",
    "    unet = models.unet.Unet(model_name, classes = 12, encoder_weights = None).cuda()\n",
    "    for i in range(3) : \n",
    "        unet.load_state_dict(torch.load('../saved_models/less_data/' + model_name + '/classifier/model' + str(i) + '.pt'))\n",
    "        current_val_iou = mean_iou(unet, valloader)\n",
    "        current_train_iou = mean_iou(unet, trainloader)\n",
    "        current_test_iou = mean_iou(unet, testloader)\n",
    "        train_ious.append(current_train_iou)\n",
    "        val_ious.append(current_val_iou)\n",
    "        test_ious.append(current_test_iou)\n",
    "    train_mean = np.mean(train_ious)\n",
    "    train_std_dev = np.std(train_ious)\n",
    "    val_mean = np.mean(val_ious)\n",
    "    val_std_dev = np.std(val_ious)\n",
    "    test_mean = np.mean(test_ious)\n",
    "    test_std_dev = np.std(test_ious)\n",
    "    print(round(train_mean, 5), ' +- ', round(train_std_dev, 5), '----', round(val_mean, 5), ' +- ', round(val_std_dev, 5), '----', round(test_mean, 5), ' +- ', round(test_std_dev, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for perc in [10, 20, 30, 40] :\n",
    "    x_train_dir = os.path.join(DATA_DIR, 'trainsmall' + str(perc))\n",
    "    y_train_dir = os.path.join(DATA_DIR, 'trainsmallannot' + str(perc))\n",
    "    train_dataset = CamVid(\n",
    "        x_train_dir,\n",
    "        y_train_dir,\n",
    "        classes = classes,\n",
    "        transform = transform\n",
    "    )\n",
    "    trainloader = DataLoader(train_dataset, batch_size = 8, shuffle = True, drop_last = True)\n",
    "\n",
    "    for model_name in ['resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'] : \n",
    "        train_ious = list()\n",
    "        val_ious = list()\n",
    "        test_ious = list()\n",
    "        unet = models.unet.Unet(model_name, classes = 12, encoder_weights = None).cuda()\n",
    "        for i in range(3) : \n",
    "            unet.load_state_dict(torch.load('../saved_models/less_data' + str(perc) + '/' + model_name + '/classifier/model' + str(i) + '.pt'))\n",
    "            current_val_iou = mean_iou(unet, valloader)\n",
    "            current_train_iou = mean_iou(unet, trainloader)\n",
    "            current_test_iou = mean_iou(unet, testloader)\n",
    "            train_ious.append(current_train_iou)\n",
    "            val_ious.append(current_val_iou)\n",
    "            test_ious.append(current_test_iou)\n",
    "        train_mean = np.mean(train_ious)\n",
    "        train_std_dev = np.std(train_ious)\n",
    "        val_mean = np.mean(val_ious)\n",
    "        val_std_dev = np.std(val_ious)\n",
    "        test_mean = np.mean(test_ious)\n",
    "        test_std_dev = np.std(test_ious)\n",
    "        print(round(train_mean, 5), ' +- ', round(train_std_dev, 5), '----', round(val_mean, 5), ' +- ', round(val_std_dev, 5), '----', round(test_mean, 5), ' +- ', round(test_std_dev, 5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluation of Traditional KD with full dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.64853 0.55583 0.37193\n",
      "0.63207 0.55299 0.38704\n",
      "0.61392 0.55576 0.37983\n",
      "0.61879 0.50527 0.3567\n",
      "0.60309 0.51407 0.35582\n"
     ]
    }
   ],
   "source": [
    "x_train_dir = os.path.join(DATA_DIR, 'train')\n",
    "y_train_dir = os.path.join(DATA_DIR, 'trainannot')\n",
    "\n",
    "x_valid_dir = os.path.join(DATA_DIR, 'val')\n",
    "y_valid_dir = os.path.join(DATA_DIR, 'valannot')\n",
    "\n",
    "x_test_dir = os.path.join(DATA_DIR, 'test')\n",
    "y_test_dir = os.path.join(DATA_DIR, 'testannot')\n",
    "\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean = [0.41189489566336, 0.4251328133025, 0.4326707089857], std = [0.27413549931506, 0.28506257482912, 0.28284674400252])\n",
    "])\n",
    "\n",
    "train_dataset = CamVid(\n",
    "    x_train_dir,\n",
    "    y_train_dir,\n",
    "    classes = classes,\n",
    "    transform = transform\n",
    ")\n",
    "\n",
    "valid_dataset = CamVid(\n",
    "    x_valid_dir,\n",
    "    y_valid_dir,\n",
    "    classes = classes, \n",
    "    transform = transform\n",
    ")\n",
    "\n",
    "test_dataset = CamVid(\n",
    "    x_test_dir,\n",
    "    y_test_dir,\n",
    "    classes = classes, \n",
    "    transform = transform\n",
    ")\n",
    "\n",
    "trainloader = DataLoader(train_dataset, batch_size = 8, shuffle = True, drop_last = True)\n",
    "valloader = DataLoader(valid_dataset, batch_size = 1, shuffle = False)\n",
    "testloader = DataLoader(test_dataset, batch_size = 1, shuffle = False)\n",
    "\n",
    "for model_name in ['resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'] : \n",
    "    unet = models.unet.Unet(model_name, classes = 12, encoder_weights = None).cuda()\n",
    "    unet.load_state_dict(torch.load('../saved_models/camvid/trad_kd/' + model_name + '/model0.pt'))\n",
    "    current_val_iou = mean_iou(unet, valloader)\n",
    "    current_train_iou = mean_iou(unet, trainloader)\n",
    "    current_test_iou = mean_iou(unet, testloader)\n",
    "    print(round(current_train_iou, 5), round(current_val_iou, 5), round(current_test_iou, 5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluation of Traditional KD with fractional dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "percentage:  10\n",
      "model:  resnet10\n",
      "0.33353 0.25293 0.2306\n",
      "model:  resnet14\n",
      "0.36537 0.28911 0.247\n",
      "model:  resnet18\n",
      "0.35355 0.26334 0.24498\n",
      "model:  resnet20\n",
      "0.34641 0.23383 0.2309\n",
      "model:  resnet26\n",
      "0.35703 0.27662 0.24255\n",
      "percentage:  20\n",
      "model:  resnet10\n",
      "0.39709 0.3285 0.26558\n",
      "model:  resnet14\n",
      "0.38971 0.32868 0.27529\n",
      "model:  resnet18\n",
      "0.4023 0.33936 0.27953\n",
      "model:  resnet20\n",
      "0.4123 0.33812 0.2688\n",
      "model:  resnet26\n",
      "0.38988 0.3105 0.25078\n",
      "percentage:  30\n",
      "model:  resnet10\n",
      "0.44408 0.35192 0.2833\n",
      "model:  resnet14\n",
      "0.4574 0.38642 0.29293\n",
      "model:  resnet18\n",
      "0.42197 0.38238 0.27754\n",
      "model:  resnet20\n",
      "0.44397 0.39396 0.28697\n",
      "model:  resnet26\n",
      "0.43627 0.35542 0.28541\n",
      "percentage:  40\n",
      "model:  resnet10\n",
      "0.51495 0.42759 0.30368\n",
      "model:  resnet14\n",
      "0.52595 0.42193 0.30817\n",
      "model:  resnet18\n",
      "0.49122 0.42616 0.30294\n",
      "model:  resnet20\n",
      "0.46955 0.40827 0.30043\n",
      "model:  resnet26\n",
      "0.49713 0.43863 0.3088\n"
     ]
    }
   ],
   "source": [
    "for perc in [10, 20, 30, 40] :\n",
    "    x_train_dir = os.path.join(DATA_DIR, 'trainsmall' + str(perc))\n",
    "    y_train_dir = os.path.join(DATA_DIR, 'trainsmallannot' + str(perc))\n",
    "    train_dataset = CamVid(\n",
    "        x_train_dir,\n",
    "        y_train_dir,\n",
    "        classes = classes,\n",
    "        transform = transform\n",
    "    )\n",
    "    trainloader = DataLoader(train_dataset, batch_size = 8, shuffle = True, drop_last = True)\n",
    "    print('percentage: ', perc)\n",
    "    for model_name in ['resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'] : \n",
    "        train_ious = list()\n",
    "        val_ious = list()\n",
    "        test_ious = list()\n",
    "        unet = models.unet.Unet(model_name, classes = 12, encoder_weights = None).cuda()\n",
    "        unet.load_state_dict(torch.load('../saved_models/camvid/less_data' + str(perc) + '/trad_kd/' + model_name + '/model0.pt'))\n",
    "        current_val_iou = mean_iou(unet, valloader)\n",
    "        current_train_iou = mean_iou(unet, trainloader)\n",
    "        current_test_iou = mean_iou(unet, testloader)\n",
    "        print('model: ', model_name)\n",
    "        print(round(current_train_iou, 5), round(current_val_iou, 5), round(current_test_iou, 5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluation of standalone training with fraction of data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "perc :  10\n",
      "model :  resnet10\n",
      "0.35172  +-  0.00724 ---- 0.25953  +-  0.01908 ---- 0.22803  +-  0.00616\n",
      "model :  resnet14\n",
      "0.35933  +-  0.00434 ---- 0.26859  +-  0.00857 ---- 0.23766  +-  0.00483\n",
      "model :  resnet18\n",
      "0.36886  +-  0.00521 ---- 0.26705  +-  0.02723 ---- 0.23669  +-  0.00859\n",
      "model :  resnet20\n",
      "0.35547  +-  0.01504 ---- 0.26502  +-  0.02425 ---- 0.24018  +-  0.00952\n",
      "model :  resnet26\n",
      "0.36588  +-  0.00516 ---- 0.26088  +-  0.01691 ---- 0.2421  +-  0.00417\n",
      "perc :  20\n",
      "model :  resnet10\n",
      "0.39674  +-  0.00505 ---- 0.31366  +-  0.00945 ---- 0.25953  +-  0.001\n",
      "model :  resnet14\n",
      "0.39985  +-  0.00263 ---- 0.3194  +-  0.00492 ---- 0.26274  +-  0.00404\n",
      "model :  resnet18\n",
      "0.40056  +-  0.01402 ---- 0.32341  +-  0.00545 ---- 0.26703  +-  0.00799\n",
      "model :  resnet20\n",
      "0.40694  +-  0.00893 ---- 0.33271  +-  0.01685 ---- 0.26685  +-  0.00318\n",
      "model :  resnet26\n",
      "0.39635  +-  0.00507 ---- 0.34098  +-  0.00898 ---- 0.27224  +-  0.00577\n",
      "perc :  30\n",
      "model :  resnet10\n",
      "0.42637  +-  0.00958 ---- 0.36693  +-  0.00476 ---- 0.28337  +-  0.00491\n",
      "model :  resnet14\n",
      "0.43292  +-  0.01892 ---- 0.3807  +-  0.01804 ---- 0.28652  +-  0.00829\n",
      "model :  resnet18\n",
      "0.42978  +-  0.00373 ---- 0.37438  +-  0.01485 ---- 0.28265  +-  0.0028\n",
      "model :  resnet20\n",
      "0.42031  +-  0.00523 ---- 0.36714  +-  0.00436 ---- 0.2809  +-  0.00098\n",
      "model :  resnet26\n",
      "0.40044  +-  0.03432 ---- 0.36584  +-  0.03175 ---- 0.27266  +-  0.01458\n",
      "perc :  40\n",
      "model :  resnet10\n",
      "0.47773  +-  0.00252 ---- 0.40054  +-  0.00193 ---- 0.29532  +-  0.00239\n",
      "model :  resnet14\n",
      "0.47876  +-  0.00967 ---- 0.41671  +-  0.00743 ---- 0.30024  +-  0.00431\n",
      "model :  resnet18\n",
      "0.47802  +-  0.00442 ---- 0.41912  +-  0.01252 ---- 0.2988  +-  0.00468\n",
      "model :  resnet20\n",
      "0.4702  +-  0.00266 ---- 0.41737  +-  0.00744 ---- 0.29893  +-  0.00245\n",
      "model :  resnet26\n",
      "0.47665  +-  0.01801 ---- 0.43587  +-  0.00916 ---- 0.30727  +-  0.0073\n"
     ]
    }
   ],
   "source": [
    "for perc in [10, 20, 30, 40] :\n",
    "    x_train_dir = os.path.join(DATA_DIR, 'trainsmall' + str(perc))\n",
    "    y_train_dir = os.path.join(DATA_DIR, 'trainsmallannot' + str(perc))\n",
    "    train_dataset = CamVid(\n",
    "        x_train_dir,\n",
    "        y_train_dir,\n",
    "        classes = classes,\n",
    "        transform = transform\n",
    "    )\n",
    "    trainloader = DataLoader(train_dataset, batch_size = 8, shuffle = True, drop_last = True)\n",
    "    print('perc : ', perc)\n",
    "    for model_name in ['resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'] : \n",
    "        print('model : ', model_name)\n",
    "        train_ious = list()\n",
    "        val_ious = list()\n",
    "        test_ious = list()\n",
    "        unet = models.unet.Unet(model_name, classes = 12, encoder_weights = None).cuda()\n",
    "        for i in range(3) : \n",
    "            unet.load_state_dict(torch.load('../saved_models/camvid/less_data/' + model_name + '/pretrained_' + str(perc) + '_' + str(i) + '.pt'))\n",
    "            current_val_iou = mean_iou(unet, valloader)\n",
    "            current_train_iou = mean_iou(unet, trainloader)\n",
    "            current_test_iou = mean_iou(unet, testloader)\n",
    "            train_ious.append(current_train_iou)\n",
    "            val_ious.append(current_val_iou)\n",
    "            test_ious.append(current_test_iou)\n",
    "        train_mean = np.mean(train_ious)\n",
    "        train_std_dev = np.std(train_ious)\n",
    "        val_mean = np.mean(val_ious)\n",
    "        val_std_dev = np.std(val_ious)\n",
    "        test_mean = np.mean(test_ious)\n",
    "        test_std_dev = np.std(test_ious)\n",
    "        print(round(train_mean, 5), ' +- ', round(train_std_dev, 5), '----', round(val_mean, 5), ' +- ', round(val_std_dev, 5), '----', round(test_mean, 5), ' +- ', round(test_std_dev, 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for perc in [10, 20, 30, 40] :\n",
    "    print('perc : ', perc)\n",
    "    for model_name in ['resnet10', 'resnet14', 'resnet18', 'resnet20', 'resnet26'] : \n",
    "        print('model : ', model_name)\n",
    "        train_ious = list()\n",
    "        val_ious = list()\n",
    "        test_ious = list()\n",
    "        unet = models.unet.Unet(model_name, classes = 12, encoder_weights = None).cuda()\n",
    "        unet.load_state_dict(torch.load('../saved_models/camvid/less_data' + str(perc) + '/simultaneous/' + model_name + '/model0.pt'))\n",
    "        current_val_iou = mean_iou(unet, valloader)\n",
    "        print(round(current_val_iou, 5))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (pyt)",
   "language": "python",
   "name": "pyt"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
